
Conversation


@djsaunde djsaunde commented Sep 29, 2025

Description

Title.

Motivation and Context

Pretrained autoregressive models treat the output logits as right-shifted by one. By doing this, we should be able to use pretrained AR models effectively for diffusion model fine-tuning!
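
As a rough illustration of what that right shift means, here is a minimal sketch (not code from this PR; shapes and the loss shown are illustrative):

import torch
import torch.nn.functional as F

# In a standard causal LM, logits[:, i] scores the token at position i + 1,
# so next-token training drops the last logit and the first label:
batch, seq_len, vocab = 1, 4, 10
logits = torch.randn(batch, seq_len, vocab)             # stand-in for outputs.logits
input_ids = torch.randint(0, vocab, (batch, seq_len))

ar_loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab), input_ids[:, 1:].reshape(-1)
)

# Diffusion fine-tuning instead wants logits[:, i] to score the token *at*
# position i, so the pretrained model's logits must be shifted right by one.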

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • Bug Fixes
    • Corrects token alignment during diffusion generation and training so predictions match the corresponding input positions. This improves sampling stability and reduces artifacts in outputs.
    • Enhances loss calculation accuracy by aligning model predictions with inputs, leading to more consistent training behavior and better-quality results during inference.

@djsaunde djsaunde requested a review from a team September 29, 2025 15:45
@djsaunde djsaunde self-assigned this Sep 29, 2025

coderabbitai bot commented Sep 29, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Introduces a new utility to shift logits to input positions and updates diffusion generation and training to apply this shift before token selection and loss computation. No public API signatures changed.

Changes

  • Logits alignment integration (src/axolotl/integrations/diffusion/generation.py, src/axolotl/integrations/diffusion/trainer.py): import and apply shift_logits_to_input_positions to outputs.logits before sampling (generation) and before loss computation (trainer), replacing direct logits usage.
  • Utility addition (src/axolotl/integrations/diffusion/utils.py): add a shift_logits_to_input_positions(logits) function that realigns next-token logits to input token positions by shifting along the sequence dimension.
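
For orientation, a minimal sketch of how the shifted logits would be consumed at the two call sites above, assuming the axolotl package from this PR is importable; the surrounding variable names and the full-sequence loss are illustrative, not the actual generation or trainer code:

import torch
import torch.nn.functional as F

from axolotl.integrations.diffusion.utils import shift_logits_to_input_positions

batch, seq_len, vocab = 2, 6, 32
raw_logits = torch.randn(batch, seq_len, vocab)          # stand-in for outputs.logits
input_ids = torch.randint(0, vocab, (batch, seq_len))    # stand-in for the batch's tokens

aligned = shift_logits_to_input_positions(raw_logits)

# Generation: token selection now reads logits at the same positions as the inputs.
chosen = aligned.argmax(dim=-1)

# Training: predictions are compared with tokens at matching positions (a real
# diffusion trainer would typically restrict this to masked positions; the full
# sequence is used here for brevity).
loss = F.cross_entropy(aligned.reshape(-1, vocab), input_ids.reshape(-1))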

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: Passed. The title succinctly describes the primary change of shifting output logits to align with input tokens in the diffusion trainer. It is concise and clear, avoiding unnecessary details like file lists or emojis, and directly communicates the key fix introduced by this PR. Although the generation logic is also updated, the central focus on training behavior makes the title sufficiently representative of the main change.
  • Docstring Coverage: Passed. No functions found in the changes; docstring coverage check skipped.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/integrations/diffusion/utils.py (1)

162-166: Implementation looks correct for logits alignment.

The function correctly shifts next-token prediction logits to align with input token positions by:

  1. Preserving the first logit position unchanged
  2. Shifting remaining logits left by one position
  3. Properly handling edge case of single-token sequences

The implementation aligns with the PR objective of adapting pretrained autoregressive models for diffusion fine-tuning.

However, consider adding a brief example in the docstring to clarify the transformation:

-    """Align next-token logits with their input token positions for diffusion."""
+    """Align next-token logits with their input token positions for diffusion.
+    
+    Example: [logit_for_pos1, logit_for_pos2, logit_for_pos3] 
+    becomes: [logit_for_pos1, logit_for_pos1, logit_for_pos2]
+    """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f437674 and da80beb.

📒 Files selected for processing (3)
  • src/axolotl/integrations/diffusion/generation.py (2 hunks)
  • src/axolotl/integrations/diffusion/trainer.py (2 hunks)
  • src/axolotl/integrations/diffusion/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/integrations/diffusion/trainer.py (1)
src/axolotl/integrations/diffusion/utils.py (2)
  • create_bidirectional_attention_mask (125-159)
  • shift_logits_to_input_positions (162-166)
src/axolotl/integrations/diffusion/generation.py (1)
src/axolotl/integrations/diffusion/utils.py (2)
  • create_bidirectional_attention_mask (125-159)
  • shift_logits_to_input_positions (162-166)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (4)
src/axolotl/integrations/diffusion/trainer.py (2)

14-14: LGTM: Import addition is correct.

The import of shift_logits_to_input_positions from the utils module follows the existing import pattern.


210-210: No issues detected with logits shifting alignment or loss computation.

src/axolotl/integrations/diffusion/generation.py (2)

10-10: LGTM: Import addition is consistent.

The import follows the same pattern as in the trainer module for consistency.


363-363: Logits shifting applied consistently in generation.

The shift is correctly applied before token sampling in the diffusion step, maintaining consistency with the training logic. This ensures the same logits alignment is used during both training and generation.


codecov bot commented Sep 29, 2025

Codecov Report

❌ Patch coverage is 75.00000% with 2 lines in your changes missing coverage. Please review.

Files with missing lines:
  • src/axolotl/integrations/diffusion/generation.py: 50.00% patch coverage, 1 line missing ⚠️
  • src/axolotl/integrations/diffusion/utils.py: 75.00% patch coverage, 1 line missing ⚠️


@NanoCode012 NanoCode012 left a comment

Test failure is unrelated: out of space.

"""Align next-token logits with their input token positions for diffusion."""
if logits.size(1) <= 1:
return logits
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
Collaborator

What is this trying to do? Concatenate the logits' first column with columns 1..N?

@djsaunde djsaunde (Member, Author) commented Oct 2, 2025

It's a bit of a hack to use pretrained causal LMs for diffusion fine-tuning. We're shifting the logits to the right by one position so that the output logits line up with the input token positions.

Unfortunately we're duplicating the first position's logits, but I couldn't think of a better way to do it. Open to ideas here.
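
A tiny demo of the duplication being discussed, using position indices in place of real logit vectors (illustrative only):

import torch

logits = torch.arange(5, dtype=torch.float).view(1, 5, 1)   # pretend logit i is just the value i
shifted = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

print(logits.flatten().tolist())    # [0.0, 1.0, 2.0, 3.0, 4.0]
print(shifted.flatten().tolist())   # [0.0, 0.0, 1.0, 2.0, 3.0]  <- position 0 duplicated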

@winglian winglian force-pushed the diffusion-shift-logits branch from da80beb to 7f6f08e on October 7, 2025 at 20:49

winglian commented Oct 7, 2025

Just as a data point, I had messed around with an early version of Dan's diffusion trainer a while back, and here's the change I made to support next-token prediction: cf8c93e. My changes may be unnecessary, but I wanted to make sure we didn't miss anything.


3 participants